"""
下载大模型及向量嵌入模型
"""
import platform
import os
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

from config import config
from pathlib import Path
import modelscope
import huggingface_hub


def download_model(model_name: str) -> str:
    """Download a model snapshot into the per-OS model directory from config.

    The root directory is taken from ``config['model_dir'][<os name>]``, where
    the OS name is ``platform.system().lower()``. The model's ``repo_id`` and
    download source (``from``) are read from ``config[model_name]``.

    Args:
        model_name: Key into ``config`` describing the model. It is also used as
            the name of the sub-directory the snapshot is downloaded into.

    Returns:
        The local download directory as a POSIX-style path string.

    Raises:
        KeyError: If the OS or the model is missing from ``config``.
        ValueError: If ``config[model_name]['from']`` is not a supported source.
    """
    os_name = platform.system().lower()
    root_dir = Path(config['model_dir'][os_name])
    # parents=True so a nested model_dir works on a fresh machine.
    root_dir.mkdir(parents=True, exist_ok=True)

    model_cfg = config[model_name]
    repo_id = model_cfg['repo_id']
    local_dir = (root_dir / model_name).as_posix()

    from_site = model_cfg['from']
    if from_site == 'hf-mirror':
        # Never hard-code access tokens in source. Read the token from the
        # environment; None lets huggingface_hub fall back to its own
        # credential resolution (e.g. a cached `huggingface-cli login`).
        huggingface_hub.snapshot_download(repo_id, local_dir=local_dir,
                                          token=os.environ.get('HF_TOKEN'))
    elif from_site == 'modelscope':
        modelscope.snapshot_download(repo_id, local_dir=local_dir)
    else:
        # Fail loudly instead of returning a directory that was never populated.
        raise ValueError(
            f"Unsupported download source {from_site!r} for model {model_name!r}; "
            "expected 'hf-mirror' or 'modelscope'"
        )

    return local_dir


if __name__ == '__main__':
    # Running this file directly downloads the 'Comet' model described in config.
    model_dir = download_model('Comet')
    # Report where the snapshot landed; the path was previously discarded.
    print(f"Model downloaded to: {model_dir}")